import dgl.nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv,GATConv,GATv2Conv,SAGEConv
import dgl.nn.pytorch as dglnn
import dgl.nn.functional as fnl
import dgl.function as fn
class GCN(nn.Module):
    """Stack of DGL ``GraphConv`` layers applied over a fixed interaction graph.

    The graph is built once at construction time from the ``src``/``dst``
    edge endpoints. It is then made bidirectional and given self-loops.
    """

    def __init__(self,
                 config,
                 num_layers,
                 src,
                 dst,
                 num_users=None,
                 num_items=None,
                 device=None):
        """
        Args:
            config: mapping that provides ``'embedding_size'``, the input and
                output width of every ``GraphConv`` layer.
            num_layers: number of ``GraphConv`` layers.
            src: source node ids of the edges (presumably user ids; confirm
                against the caller).
            dst: destination node ids (presumably item ids). They are used
                directly as node ids, so they must already share an id space
                with ``src`` (TODO confirm).
            num_users, num_items: optional counts. When both are given, the
                graph gets exactly ``num_users + num_items`` nodes. Otherwise
                DGL infers the node count from the largest id in src/dst.
            device: optional device to move the graph to. ``None`` leaves the
                graph on the device DGL created it on.
        """
        super(GCN, self).__init__()
        # Edge endpoints (user ids -> src, item ids -> dst); the graph built
        # from them is made bidirectional.
        self.src = src
        self.dst = dst
        # create_graph() reads these three attributes, so they must be set
        # before it is called.
        self.num_users = num_users
        self.num_items = num_items
        self.device = device
        self.g = self.create_graph()
        self.latent_dim = config['embedding_size']
        self.layers = nn.ModuleList(
            [GraphConv(self.latent_dim, self.latent_dim) for _ in range(num_layers)]
        )

    def create_graph(self):
        """Build the bidirected, self-looped DGL graph from ``src``/``dst``."""
        num_nodes = None
        if self.num_users is not None and self.num_items is not None:
            num_nodes = self.num_users + self.num_items
        g = dgl.graph((self.src, self.dst), num_nodes=num_nodes)
        # Add reverse edges so messages flow in both directions.
        g = dgl.to_bidirected(g)
        # Self-loops give every node at least one incoming edge, so each node
        # also aggregates its own features.
        g = dgl.add_self_loop(g)
        if self.device is not None:
            g = g.to(self.device)
        return g

    def forward(self, h):
        """Apply every ``GraphConv`` layer in sequence.

        Args:
            h: node features with one row per graph node (presumably
               ``(num_nodes, embedding_size)``; confirm against the caller).

        Returns:
            The output of the last layer. No activation is applied between
            layers.
        """
        # Also exposes the input features on the graph as ndata['h']; the
        # layers themselves receive h explicitly.
        self.g.ndata['h'] = h
        for layer in self.layers:
            h = layer(self.g, h)
        return h
class LightGCN(nn.Module):
    """LightGCN-style propagation: parameter-free graph convolutions whose
    per-layer outputs (including the input) are averaged."""

    def __init__(self,
                 num_users,
                 num_items,
                 embedding_dim,
                 num_layers
                 ):
        """
        Args:
            num_users: number of leading rows of the embedding matrix that
                belong to users.
            num_items: number of following rows that belong to items.
            embedding_dim: embedding width (input and output of each layer).
            num_layers: number of propagation layers.
        """
        super(LightGCN, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        # The graph is supplied to forward() rather than built here.
        # GraphConv with weight=False and bias=False only performs normalized
        # message passing and applies no linear transformation.
        self.layers = nn.ModuleList([GraphConv(embedding_dim,
                                               embedding_dim,
                                               norm='both',
                                               weight=False,
                                               bias=False) for _ in range(num_layers)])

    def forward(self, h, g):
        """Propagate embeddings through ``g`` and average all layer outputs.

        Args:
            h: stacked user-then-item embeddings. The first ``num_users`` rows
               are returned as users and the next ``num_items`` rows as items,
               so ``h`` must have exactly ``num_users + num_items`` rows
               (torch.split raises otherwise).
            g: DGL graph the GraphConv layers run on.

        Returns:
            Tuple ``(user_embeddings, item_embeddings)``.
        """
        all_embeddings = [h]
        # Degree normalization is handled inside GraphConv via norm='both'.
        for layer in self.layers:
            # Each layer must consume the previous layer's output, so the
            # k-th entry is a k-hop propagation. Passing the original h to
            # every layer would make all layers compute the same single hop.
            h = layer(g, h)
            all_embeddings.append(h)
        # Average the input and every layer's output (uniform layer weights).
        final_embeddings = torch.stack(all_embeddings, dim=1).mean(dim=1)
        # Split back into the user block and the item block.
        user_embeddings, item_embeddings = torch.split(
            final_embeddings, [self.num_users, self.num_items], dim=0)
        return user_embeddings, item_embeddings
class GAT(nn.Module):
    """Stack of multi-head ``GATConv`` layers over a fixed interaction graph.

    After each layer, the per-head outputs are averaged, so every layer
    returns ``latent_dim`` features per node.
    """

    def __init__(self,
                 latent_dim,
                 num_layers,
                 num_heads,
                 src,
                 dst,
                 device='cuda'):
        """
        Args:
            latent_dim: input and output feature width of every layer.
            num_layers: number of ``GATConv`` layers.
            num_heads: attention heads per layer.
            src: source node ids of the edges (user ids).
            dst: destination node ids (item ids).
            device: device the graph is moved to. Defaults to ``'cuda'``, the
                previously hard-coded value.
        """
        super(GAT, self).__init__()
        # Edge endpoints (user ids -> src, item ids -> dst); the graph built
        # from them is made bidirectional.
        self.src = src
        self.dst = dst
        self.device = device
        self.g = self.create_graph()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.layers = nn.ModuleList(
            [GATConv(self.latent_dim, self.latent_dim, num_heads=self.num_heads)
             for _ in range(num_layers)]
        )

    def create_graph(self):
        """Build the bidirected, self-looped graph and move it to ``self.device``."""
        g = dgl.to_bidirected(dgl.graph((self.src, self.dst)))
        g = dgl.add_self_loop(g).to(self.device)
        return g

    def forward(self, h):
        """Run every layer, averaging attention heads (dim 1) after each one."""
        # Also exposes the input features on the graph as ndata['h'].
        self.g.ndata['h'] = h
        for layer in self.layers:
            h = layer(self.g, h)
            # GATConv outputs one slice per head along dim 1; collapse them.
            h = torch.mean(h, dim=1)
        return h

class GATv2(nn.Module):
    """Stack of multi-head ``GATv2Conv`` layers over a fixed interaction graph.

    After each layer, the per-head outputs are averaged, so every layer
    returns ``latent_dim`` features per node.
    """

    def __init__(self, latent_dim, num_layers, num_heads, src, dst, device='cuda'):
        """
        Args:
            latent_dim: input and output feature width of every layer.
            num_layers: number of ``GATv2Conv`` layers.
            num_heads: attention heads per layer.
            src: source node ids of the edges (user ids).
            dst: destination node ids (item ids).
            device: device the graph is moved to. Defaults to ``'cuda'``, the
                previously hard-coded value.
        """
        super(GATv2, self).__init__()
        # Edge endpoints (user ids -> src, item ids -> dst); the graph built
        # from them is made bidirectional.
        self.src = src
        self.dst = dst
        self.device = device
        self.g = self.create_graph()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.layers = nn.ModuleList(
            [GATv2Conv(self.latent_dim, self.latent_dim, num_heads=self.num_heads)
             for _ in range(num_layers)]
        )

    def create_graph(self):
        """Build the bidirected, self-looped graph and move it to ``self.device``."""
        g = dgl.to_bidirected(dgl.graph((self.src, self.dst)))
        g = dgl.add_self_loop(g).to(self.device)
        return g

    def forward(self, h):
        """Run every layer, averaging attention heads (dim 1) after each one."""
        # Also exposes the input features on the graph as ndata['h'].
        self.g.ndata['h'] = h
        for layer in self.layers:
            h = layer(self.g, h)
            # GATv2Conv outputs one slice per head along dim 1; collapse them.
            h = torch.mean(h, dim=1)
        return h

class GraphSAGE(nn.Module):
    """Stack of mean-aggregator ``SAGEConv`` layers over a fixed interaction graph."""

    def __init__(self, latent_dim, num_layers, src, dst, device='cuda'):
        """
        Args:
            latent_dim: input and output feature width of every layer.
            num_layers: number of ``SAGEConv`` layers.
            src: source node ids of the edges (user ids).
            dst: destination node ids (item ids).
            device: device the graph is moved to. Defaults to ``'cuda'``, the
                previously hard-coded value.
        """
        super(GraphSAGE, self).__init__()
        # Edge endpoints (user ids -> src, item ids -> dst); the graph built
        # from them is made bidirectional.
        self.src = src
        self.dst = dst
        self.device = device
        self.g = self.create_graph()
        self.latent_dim = latent_dim
        self.layers = nn.ModuleList(
            [SAGEConv(self.latent_dim, self.latent_dim, 'mean') for _ in range(num_layers)]
        )

    def create_graph(self):
        """Build the bidirected, self-looped graph and move it to ``self.device``."""
        g = dgl.to_bidirected(dgl.graph((self.src, self.dst)))
        g = dgl.add_self_loop(g).to(self.device)
        return g

    def forward(self, h):
        """Apply every ``SAGEConv`` layer in sequence and return the last output."""
        # Also exposes the input features on the graph as ndata['h'].
        self.g.ndata['h'] = h
        for layer in self.layers:
            h = layer(self.g, h)
        return h